from functools import partial

import mltraq
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import shuffle


def load(run: mltraq.Run):
    """Store a seeded, shuffled copy of the IRIS dataset on ``run.vars``.

    The features go to ``run.vars.X`` and the labels to ``run.vars.y``.
    Shuffling uses ``run.params.seed`` as ``random_state``, so the sample order
    is reproducible per seed. ``run.vars`` is only accessible while the run is
    executing, so these arrays are working state for the later steps.
    """
    features, labels = load_iris(return_X_y=True)
    run.vars.X, run.vars.y = shuffle(features, labels, random_state=run.params.seed)


def train_predict(run: mltraq.Run):
    """Fit the run's classifier on a training prefix and predict the held-out rest.

    Reads ``run.vars.X`` / ``run.vars.y`` (set by ``load``). The model is built
    by calling ``run.params.classifier(random_state=run.params.seed)``. Its class
    name is stored in ``run.fields.model_name``, which is persisted to the
    database. ``run.vars.y_pred`` and ``run.vars.y_true`` are left behind for
    ``evaluate``.
    """
    # First 100 (shuffled) samples train the model; with IRIS's 150 samples this
    # leaves 50 for evaluation.
    n_train = 100
    X_train, X_test = run.vars.X[:n_train], run.vars.X[n_train:]
    y_train, y_test = run.vars.y[:n_train], run.vars.y[n_train:]

    estimator = run.params.classifier(random_state=run.params.seed).fit(X_train, y_train)
    run.fields.model_name = type(estimator).__name__

    run.vars.y_pred = estimator.predict(X_test)
    run.vars.y_true = y_test


def evaluate(run: mltraq.Run):
    """Record the hold-out accuracy computed from ``train_predict``'s predictions.

    Compares ``run.vars.y_true`` with ``run.vars.y_pred`` and stores the result
    in ``run.fields.accuracy`` so it is tracked with the run.
    """
    score = accuracy_score(run.vars.y_true, run.vars.y_pred)
    run.fields.accuracy = score


# Connect to the MLtraq session and create an experiment.
session = mltraq.create_session()
experiment = session.create_experiment()

# Use a parameter grid to define the experiment's runs: one run for every
# (classifier, seed) combination. Each classifier entry must be a callable that
# accepts `random_state`, since train_predict instantiates it that way.
# NOTE(review): KMeans is unsupervised. Its predictions are cluster indices, not
# IRIS class labels, so its "accuracy" depends on an arbitrary cluster numbering.
# Confirm it is included intentionally (e.g. as a reference point).
experiment.add_runs(
    classifier=[
        partial(DummyClassifier, strategy="most_frequent"),
        partial(LogisticRegression, max_iter=1000),
        DecisionTreeClassifier,
        RandomForestClassifier,
        partial(KMeans, n_clusters=3, n_init="auto"),
    ],
    seed=range(10),
)

# Execute the experiment: every run goes through load -> train_predict -> evaluate,
# with runs executed in parallel.
experiment.execute(steps=[load, train_predict, evaluate])

# Stats on the experiment
print("Experiment:")
print(experiment)
print("\n--")

# Show the tracked fields of one run as a sample. Note that `first()` returns the
# first run, not a randomly chosen one, despite the printed label.
print("A random run:")
print(experiment.runs.first().fields)
print("\n--")

# Query the results and report the ML models leaderboard: mean of the numeric
# columns per model, best mean accuracy first.
df_leaderboard = (
    experiment.runs.df().groupby("model_name").mean(numeric_only=True).sort_values(by="accuracy", ascending=False)
)
print("Leaderboard:")
print(df_leaderboard)
